{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overview\n",
    "\n",
    "In this project, I will use an Alternating Least Squares (ALS) algorithm with Spark APIs to predict the ratings for the movies in [MovieLens Datasets](https://grouplens.org/datasets/movielens/latest/)\n",
    "\n",
    "## [Recommender system](https://en.wikipedia.org/wiki/Recommender_system)\n",
    "A recommendation system is basically an information filtering system that seeks to predict the \"rating\" or \"preference\" a user would give to an item. It is widely used in different internet / online business such as Amazon, Netflix, Spotify, or social media like Facebook and Youtube. By using recommender systems, those companies are able to provide better or more suited products/services/contents that are personalized to a user based on his/her historical consumer behaviors\n",
    "\n",
    "Recommender systems typically produce a list of recommendations through collaborative filtering or through content-based filtering\n",
    "\n",
    "This project will focus on collaborative filtering and use Alternating Least Squares (ALS) algorithm to make movie predictions\n",
    "\n",
    "\n",
    "## [Alternating Least Squares](https://endymecy.gitbooks.io/spark-ml-source-analysis/content/%E6%8E%A8%E8%8D%90/papers/Large-scale%20Parallel%20Collaborative%20Filtering%20the%20Netflix%20Prize.pdf)\n",
    "ALS is one of the low rank matrix approximation algorithms for collaborative filtering. ALS decomposes user-item matrix into two low rank matrixes: user matrix and item matrix. In collaborative filtering, users and products are described by a small set of latent factors that can be used to predict missing entries. And ALS algorithm learns these latent factors by matrix factorization\n",
    "\n",
    "\n",
    "## Data Sets\n",
    "I use [MovieLens Datasets](https://grouplens.org/datasets/movielens/latest/).\n",
    "This dataset (ml-latest.zip) describes 5-star rating and free-text tagging activity from [MovieLens](http://movielens.org), a movie recommendation service. It contains 27753444 ratings and 1108997 tag applications across 58098 movies. These data were created by 283228 users between January 09, 1995 and September 26, 2018. This dataset was generated on September 26, 2018.\n",
    "\n",
    "Users were selected at random for inclusion. All selected users had rated at least 1 movies. No demographic information is included. Each user is represented by an id, and no other information is provided.\n",
    "\n",
    "The data are contained in the files `genome-scores.csv`, `genome-tags.csv`, `links.csv`, `movies.csv`, `ratings.csv` and `tags.csv`.\n",
    "\n",
    "## Project Content\n",
    "1. Load Data\n",
    "2. Spark SQL and OLAP\n",
    "3. Spark ALS based approach for training model\n",
    "4. ALS Model Selection and Evaluation\n",
    "5. Model testing\n",
    "6. Make movie recommendation to myself"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "# spark imports\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import UserDefinedFunction, explode, desc\n",
    "from pyspark.sql.types import StringType, ArrayType\n",
    "from pyspark.mllib.recommendation import ALS\n",
    "\n",
    "# data science imports\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# visualization imports\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# spark config\n",
    "spark = SparkSession \\\n",
    "    .builder \\\n",
    "    .appName(\"movie recommendation\") \\\n",
    "    .config(\"spark.driver.maxResultSize\", \"96g\") \\\n",
    "    .config(\"spark.driver.memory\", \"96g\") \\\n",
    "    .config(\"spark.executor.memory\", \"8g\") \\\n",
    "    .config(\"spark.master\", \"local[12]\") \\\n",
    "    .getOrCreate()\n",
    "# get spark context\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path config\n",
    "data_path = os.path.join(os.environ['DATA_PATH'], 'MovieLens')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "movies = spark.read.load(os.path.join(data_path, 'movies.csv'), format='csv', header=True, inferSchema=True)\n",
    "ratings = spark.read.load(os.path.join(data_path, 'ratings.csv'), format='csv', header=True, inferSchema=True)\n",
    "links = spark.read.load(os.path.join(data_path, 'links.csv'), format='csv', header=True, inferSchema=True)\n",
    "tags = spark.read.load(os.path.join(data_path, 'tags.csv'), format='csv', header=True, inferSchema=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### basic inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+--------------------+--------------------+\n",
      "|movieId|               title|              genres|\n",
      "+-------+--------------------+--------------------+\n",
      "|      1|    Toy Story (1995)|Adventure|Animati...|\n",
      "|      2|      Jumanji (1995)|Adventure|Childre...|\n",
      "|      3|Grumpier Old Men ...|      Comedy|Romance|\n",
      "+-------+--------------------+--------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "movies.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+-------+------+----------+\n",
      "|userId|movieId|rating| timestamp|\n",
      "+------+-------+------+----------+\n",
      "|     1|    307|   3.5|1256677221|\n",
      "|     1|    481|   3.5|1256677456|\n",
      "|     1|   1091|   1.5|1256677471|\n",
      "+------+-------+------+----------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "ratings.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------+------+\n",
      "|movieId|imdbId|tmdbId|\n",
      "+-------+------+------+\n",
      "|      1|114709|   862|\n",
      "|      2|113497|  8844|\n",
      "|      3|113228| 15602|\n",
      "+-------+------+------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "links.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+-------+--------+----------+\n",
      "|userId|movieId|     tag| timestamp|\n",
      "+------+-------+--------+----------+\n",
      "|    14|    110|    epic|1443148538|\n",
      "|    14|    110|Medieval|1443148532|\n",
      "|    14|    260|  sci-fi|1442169410|\n",
      "+------+-------+--------+----------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "tags.show(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spark SQL and OLAP\n",
    "\n",
    "Below are the questions I'd like to ask:\n",
    "1. What are the ratings?\n",
    "2. What is minimum number of ratings per user and minimum number of ratings per movie?\n",
    "3. How many movies are rated by only one user?\n",
    "4. What is the total number of users in the data sets?\n",
    "5. What is the total number of movies in the data sets?\n",
    "6. How many movies are rated by users? List movies not rated yet?\n",
    "7. List all movie genres\n",
    "8. Find out the number of movies for each category\n",
    "9. Calculate the total rating count for every movie\n",
    "10. Get a count plot for each rating"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What are the ratings?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Distinct values of ratings:\n",
      "[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]\n"
     ]
    }
   ],
   "source": [
    "print('Distinct values of ratings:')\n",
    "print(sorted(ratings.select('rating').distinct().rdd.map(lambda r: r[0]).collect()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What is minimum number of ratings per user and minimum number of ratings per movie?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For the users that rated movies and the movies that were rated:\n",
      "Minimum number of ratings per user is 1\n",
      "Minimum number of ratings per movie is 1\n"
     ]
    }
   ],
   "source": [
    "tmp1 = ratings.groupBy(\"userID\").count().toPandas()['count'].min()\n",
    "tmp2 = ratings.groupBy(\"movieId\").count().toPandas()['count'].min()\n",
    "print('For the users that rated movies and the movies that were rated:')\n",
    "print('Minimum number of ratings per user is {}'.format(tmp1))\n",
    "print('Minimum number of ratings per movie is {}'.format(tmp2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How many movies are rated by only one user?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10155 out of 53889 movies are rated by only one user\n"
     ]
    }
   ],
   "source": [
    "tmp1 = sum(ratings.groupBy(\"movieId\").count().toPandas()['count'] == 1)\n",
    "tmp2 = ratings.select('movieId').distinct().count()\n",
    "print('{} out of {} movies are rated by only one user'.format(tmp1, tmp2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What is the total number of users in the data sets?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have a total of 283228 distinct users in the data sets\n"
     ]
    }
   ],
   "source": [
    "tmp = ratings.select('userID').distinct().count()\n",
    "print('We have a total of {} distinct users in the data sets'.format(tmp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What is the total number of movies in the data sets?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have a total of 58098 distinct movies in the data sets\n"
     ]
    }
   ],
   "source": [
    "tmp = movies.select('movieID').distinct().count()\n",
    "print('We have a total of {} distinct movies in the data sets'.format(tmp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How many movies are rated by users? List movies not rated yet?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have a total of 53889 distinct movies that are rated by users in ratings table\n",
      "We have 4209 movies that are not rated yet\n"
     ]
    }
   ],
   "source": [
    "tmp1 = movies.select('movieID').distinct().count()\n",
    "tmp2 = ratings.select('movieID').distinct().count()\n",
    "print('We have a total of {} distinct movies that are rated by users in ratings table'.format(tmp2))\n",
    "print('We have {} movies that are not rated yet'.format(tmp1-tmp2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "List movies that are not rated yet: \n",
      "+-------+--------------------+\n",
      "|movieId|               title|\n",
      "+-------+--------------------+\n",
      "|  25817|Break of Hearts (...|\n",
      "|  26361|Baby Blue Marine ...|\n",
      "|  27153|Can't Be Heaven (...|\n",
      "|  27433|        Bark! (2002)|\n",
      "|  31945|Always a Bridesma...|\n",
      "|  52696|Thousand and One ...|\n",
      "|  58209|Alex in Wonder (S...|\n",
      "|  60234|   Shock, The (1923)|\n",
      "|  69565|Bling: A Planet R...|\n",
      "|  69834|       Agency (1980)|\n",
      "+-------+--------------------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# create a temp SQL table view for easier query\n",
    "movies.createOrReplaceTempView(\"movies\")\n",
    "ratings.createOrReplaceTempView(\"ratings\")\n",
    "print('List movies that are not rated yet: ')\n",
    "# SQL query (NOTE: WHERE ... NOT IN ... == ... LEFT JOIN ... WHERE ... IS NULL)\n",
    "# Approach 1\n",
    "spark.sql(\n",
    "    \"SELECT movieId, title \"\n",
    "    \"FROM movies \"\n",
    "    \"WHERE movieId NOT IN (SELECT distinct(movieId) FROM ratings)\"\n",
    ").show(10)\n",
    "# Approach 2\n",
    "# spark.sql(\n",
    "#     \"SELECT m.movieId, m.title \"\n",
    "#     \"FROM movies m LEFT JOIN ratings r ON m.movieId=r.movieId \"\n",
    "#     \"WHERE r.movieId IS NULL\"\n",
    "# ).show(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "List all movie genres"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All distinct genres: \n",
      "+------------------+\n",
      "|            genres|\n",
      "+------------------+\n",
      "|             Crime|\n",
      "|           Romance|\n",
      "|          Thriller|\n",
      "|         Adventure|\n",
      "|             Drama|\n",
      "|               War|\n",
      "|       Documentary|\n",
      "|           Fantasy|\n",
      "|           Mystery|\n",
      "|           Musical|\n",
      "|         Animation|\n",
      "|         Film-Noir|\n",
      "|(no genres listed)|\n",
      "|              IMAX|\n",
      "|            Horror|\n",
      "|           Western|\n",
      "|            Comedy|\n",
      "|          Children|\n",
      "|            Action|\n",
      "|            Sci-Fi|\n",
      "+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# define a udf for splitting the genres string\n",
    "splitter = UserDefinedFunction(lambda x: x.split('|'), ArrayType(StringType()))\n",
    "# query\n",
    "print('All distinct genres: ')\n",
    "movies.select(explode(splitter(\"genres\")).alias(\"genres\")).distinct().show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Find out the number of movies for each category"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counts of movies per genre\n",
      "+------------------+-----+\n",
      "|            genres|count|\n",
      "+------------------+-----+\n",
      "|             Drama|24144|\n",
      "|            Comedy|15956|\n",
      "|          Thriller| 8216|\n",
      "|           Romance| 7412|\n",
      "|            Action| 7130|\n",
      "|            Horror| 5555|\n",
      "|       Documentary| 5118|\n",
      "|             Crime| 5105|\n",
      "|(no genres listed)| 4266|\n",
      "|         Adventure| 4067|\n",
      "|            Sci-Fi| 3444|\n",
      "|           Mystery| 2773|\n",
      "|          Children| 2749|\n",
      "|         Animation| 2663|\n",
      "|           Fantasy| 2637|\n",
      "|               War| 1820|\n",
      "|           Western| 1378|\n",
      "|           Musical| 1113|\n",
      "|         Film-Noir|  364|\n",
      "|              IMAX|  197|\n",
      "+------------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print('Counts of movies per genre')\n",
    "movies.select('movieID', explode(splitter(\"genres\")).alias(\"genres\")) \\\n",
    "    .groupby('genres') \\\n",
    "    .count() \\\n",
    "    .sort(desc('count')) \\\n",
    "    .show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spark ALS based approach for training model\n",
    "1. Reload data\n",
    "2. Split data into train, validation, test\n",
    "3. ALS model selection and evaluation\n",
    "4. Model testing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reload data\n",
    "We will use an RDD-based API from pyspark.mllib to predict the ratings, so let's reload \"ratings.csv\" using sc.textFile and then convert it to the form of (user, item, rating) tuples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(1, 307, 3.5), (1, 481, 3.5), (1, 1091, 1.5)]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load data\n",
    "movie_rating = sc.textFile(os.path.join(data_path, 'ratings.csv'))\n",
    "# preprocess data -- only need [\"userId\", \"movieId\", \"rating\"]\n",
    "header = movie_rating.take(1)[0]\n",
    "rating_data = movie_rating \\\n",
    "    .filter(lambda line: line!=header) \\\n",
    "    .map(lambda line: line.split(\",\")) \\\n",
    "    .map(lambda tokens: (int(tokens[0]), int(tokens[1]), float(tokens[2]))) \\\n",
    "    .cache()\n",
    "# check three rows\n",
    "rating_data.take(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split data\n",
    "Now we split the data into training/validation/testing sets using a 6/2/2 ratio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PythonRDD[166] at RDD at PythonRDD.scala:52"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train, validation, test = rating_data.randomSplit([6, 2, 2], seed=99)\n",
    "# cache data\n",
    "train.cache()\n",
    "validation.cache()\n",
    "test.cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ALS model selection and evaluation\n",
    "With the ALS model, we can use a grid search to find the optimal hyperparameters."
   ]
  },
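  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the selection metric, the RMSE that the grid search below computes on the validation set can be written as follows (the symbols $V$, $r_{ui}$ and $\\hat{r}_{ui}$ are notation assumed here for illustration):\n",
    "\n",
    "$$\\mathrm{RMSE} = \\sqrt{\\frac{1}{|V|} \\sum_{(u, i) \\in V} \\left( r_{ui} - \\hat{r}_{ui} \\right)^2}$$\n",
    "\n",
    "Here $V$ is the set of validation (user, movie) pairs for which the model returns a prediction, $r_{ui}$ is the observed rating and $\\hat{r}_{ui}$ is the predicted one."
   ]
  },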
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_ALS(train_data, validation_data, num_iters, reg_param, ranks):\n",
    "    \"\"\"\n",
    "    Grid Search Function to select the best model based on RMSE of hold-out data\n",
    "    \"\"\"\n",
    "    # initial\n",
    "    min_error = float('inf')\n",
    "    best_rank = -1\n",
    "    best_regularization = 0\n",
    "    best_model = None\n",
    "    for rank in ranks:\n",
    "        for reg in reg_param:\n",
    "            # train ALS model\n",
    "            model = ALS.train(\n",
    "                ratings=train_data,    # (userID, productID, rating) tuple\n",
    "                iterations=num_iters,\n",
    "                rank=rank,\n",
    "                lambda_=reg,           # regularization param\n",
    "                seed=99)\n",
    "            # make prediction\n",
    "            valid_data = validation_data.map(lambda p: (p[0], p[1]))\n",
    "            predictions = model.predictAll(valid_data).map(lambda r: ((r[0], r[1]), r[2]))\n",
    "            # get the rating result\n",
    "            ratesAndPreds = validation_data.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)\n",
    "            # get the RMSE\n",
    "            MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean()\n",
    "            error = math.sqrt(MSE)\n",
    "            print('{} latent factors and regularization = {}: validation RMSE is {}'.format(rank, reg, error))\n",
    "            if error < min_error:\n",
    "                min_error = error\n",
    "                best_rank = rank\n",
    "                best_regularization = reg\n",
    "                best_model = model\n",
    "    print('\\nThe best model has {} latent factors and regularization = {}'.format(best_rank, best_regularization))\n",
    "    return best_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8 latent factors and regularization = 0.001: validation RMSE is 0.8786891340427128\n",
      "8 latent factors and regularization = 0.01: validation RMSE is 0.8505897623504987\n",
      "8 latent factors and regularization = 0.05: validation RMSE is 0.8180916431113182\n",
      "8 latent factors and regularization = 0.1: validation RMSE is 0.8191175872308086\n",
      "8 latent factors and regularization = 0.2: validation RMSE is 0.8649215618478324\n",
      "10 latent factors and regularization = 0.001: validation RMSE is 0.8881965019270155\n",
      "10 latent factors and regularization = 0.01: validation RMSE is 0.8549168385580616\n",
      "10 latent factors and regularization = 0.05: validation RMSE is 0.8175174972199897\n",
      "10 latent factors and regularization = 0.1: validation RMSE is 0.8196465380445364\n",
      "10 latent factors and regularization = 0.2: validation RMSE is 0.8654977750796198\n",
      "12 latent factors and regularization = 0.001: validation RMSE is 0.9003310016924109\n",
      "12 latent factors and regularization = 0.01: validation RMSE is 0.8639878306627452\n",
      "12 latent factors and regularization = 0.05: validation RMSE is 0.819756733762523\n",
      "12 latent factors and regularization = 0.1: validation RMSE is 0.8215426348742352\n",
      "12 latent factors and regularization = 0.2: validation RMSE is 0.8669398031872779\n",
      "14 latent factors and regularization = 0.001: validation RMSE is 0.9084990840999155\n",
      "14 latent factors and regularization = 0.01: validation RMSE is 0.8665118066902521\n",
      "14 latent factors and regularization = 0.05: validation RMSE is 0.8185366551968787\n",
      "14 latent factors and regularization = 0.1: validation RMSE is 0.8217444486510747\n",
      "14 latent factors and regularization = 0.2: validation RMSE is 0.8674125542187208\n",
      "16 latent factors and regularization = 0.001: validation RMSE is 0.9168763335467992\n",
      "16 latent factors and regularization = 0.01: validation RMSE is 0.8712944863722764\n",
      "16 latent factors and regularization = 0.05: validation RMSE is 0.8180687013058069\n",
      "16 latent factors and regularization = 0.1: validation RMSE is 0.8211544150163288\n",
      "16 latent factors and regularization = 0.2: validation RMSE is 0.8672161234482479\n",
      "18 latent factors and regularization = 0.001: validation RMSE is 0.9247458796091963\n",
      "18 latent factors and regularization = 0.01: validation RMSE is 0.8761932525602091\n",
      "18 latent factors and regularization = 0.05: validation RMSE is 0.8161380007132761\n",
      "18 latent factors and regularization = 0.1: validation RMSE is 0.8204817355743963\n",
      "18 latent factors and regularization = 0.2: validation RMSE is 0.8671101536166931\n",
      "20 latent factors and regularization = 0.001: validation RMSE is 0.9344471859253237\n",
      "20 latent factors and regularization = 0.01: validation RMSE is 0.8812261920317113\n",
      "20 latent factors and regularization = 0.05: validation RMSE is 0.8155352884762789\n",
      "20 latent factors and regularization = 0.1: validation RMSE is 0.818405855491128\n",
      "20 latent factors and regularization = 0.2: validation RMSE is 0.866198213276277\n",
      "\n",
      "The best model has 20 latent factors and regularization = 0.05\n",
      "Total Runtime: 2024.12 seconds\n"
     ]
    }
   ],
   "source": [
    "# hyper-param config\n",
    "num_iterations = 10\n",
    "ranks = [8, 10, 12, 14, 16, 18, 20]\n",
    "reg_params = [0.001, 0.01, 0.05, 0.1, 0.2]\n",
    "\n",
    "# grid search and select best model\n",
    "start_time = time.time()\n",
    "final_model = train_ALS(train, validation, num_iterations, reg_params, ranks)\n",
    "\n",
    "print ('Total Runtime: {:.2f} seconds'.format(time.time() - start_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ALS model learning curve\n",
    "As we increase number of iterations in training ALS, we can see how RMSE changes and whether or not model is overfitted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_learning_curve(arr_iters, train_data, validation_data, reg, rank):\n",
    "    \"\"\"\n",
    "    Plot function to show learning curve of ALS\n",
    "    \"\"\"\n",
    "    errors = []\n",
    "    for num_iters in arr_iters:\n",
    "        # train ALS model\n",
    "        model = ALS.train(\n",
    "            ratings=train_data,    # (userID, productID, rating) tuple\n",
    "            iterations=num_iters,\n",
    "            rank=rank,\n",
    "            lambda_=reg,           # regularization param\n",
    "            seed=99)\n",
    "        # make prediction\n",
    "        valid_data = validation_data.map(lambda p: (p[0], p[1]))\n",
    "        predictions = model.predictAll(valid_data).map(lambda r: ((r[0], r[1]), r[2]))\n",
    "        # get the rating result\n",
    "        ratesAndPreds = validation_data.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)\n",
    "        # get the RMSE\n",
    "        MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean()\n",
    "        error = math.sqrt(MSE)\n",
    "        # add to errors\n",
    "        errors.append(error)\n",
    "\n",
    "    # plot\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.plot(arr_iters, errors)\n",
    "    plt.xlabel('number of iterations')\n",
    "    plt.ylabel('RMSE')\n",
    "    plt.title('ALS Learning Curve')\n",
    "    plt.grid(True)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAtQAAAGDCAYAAAALTociAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xt4HAd57/Hfu7uWZEm2rPEd33ZJUpOQkNixVvQQQJRSrqehLeXW0rSHNoXSBmj7tPTyUMrhtPQcTml5KNCUUqCFUE64FGgLpCSC0ILlOImDnYBj8N2OnVi2bFm2Lrvv+WNH0mqt1W21mtHu9/Owz8593l3NQ347fmfG3F0AAAAA5iYRdQEAAADAYkagBgAAACpAoAYAAAAqQKAGAAAAKkCgBgAAACpAoAYAAAAqQKAGgDpiZn9oZh+Nug4AqCUEagCQZGbdZnbWzBpLpn/czN5TZp1bzexhMztvZk+Z2b1mlimzbNntLCR3/zN3/9VqbNsK7jCzvWZ20cyOmdn/M7MbqrE/AIgLAjWAumdmaUnPleSSfnqG61wt6ZOSfkdSm6SMpL+RlKtKkTOrKRXVvkN/Lemtku6QFEj6MUlflPTy2W4oBp8FAGaMQA0A0i9J+q6kj0u6bYbr3CTpoLt/wwsuuPvn3P3IbHduZs8ws3vMrNfMfmBmry6a93Izeyg8C37UzN5VNC9tZm5mbzSzI5LuLZp2m5kdCc+c/1HROu8ys38qWb/cskvN7BPhmfvHzOz3zOxYmc9wjaS3SHqdu9/r7oPuPuDun3L394bLdJvZrxat88tm9u2icTezt5jZ45IeN7MPm9n7SvbzL2b22+Hw08zsc2b2pJkdNLM7ZvvdA8B8IFADQCFQfyp8vdjM1s5gnQclPcPM3m9mLzCz1rns2MxaJN0j6dOS1kh6raQPmdl14SIXw/pWqHCm981m9sqSzTxf0rWSXlw07RZJWyW9UNI7zezaKcoot+yfSEpLerqkF0n6xSm28UJJx9y9Z4plZuKVkjolXSfpLkmvMTOTJDNrl/RTkj5jZglJX5a0R9KGcP9vM7MXT7pVAKgiAjWAumZmt0jaIumz7r5b0g8lvX669dz9R5K6VAhzn5X0VNgnPdtg/QpJh9z9H9x9xN0fkvQ5ST8f7qfb3b/n7nl3f0SFkPn8km28y90vuvuloml/6u6X3H2PCqHzxilqKLfsqyX9mbufdfdjkj4wxTZWSjo5w888lT93997ws9yvQhvOc8N5r5L0HXc/IalD0mp3f7e7D4V/j79T4QcJACwoAjWAenebpK+7+1Ph+Kc1w7YPd/+uu7/a3VerEPqeJ+mPplmt1BZJnWZ2bvQl6RckrZMkM+s0s/vCtoY+SW+StKpkG0cn2e4TRcMDkqYK+uWWfVrJtifbz6gzktZPMX+mxvbh7i7pM5JeF056vQr/iiAVvrenlXxvfyhpJv+6AADzios+ANQtM1uqwlnYpJmNhspGSSvM7MbwjO2MuPsuM/u8pOtnWcZRSd909xeVmf9pSR+U9FJ3v2xmf6UrA7XPcp8zdVLSRkmPhuObplj2G5L+xsx2uPsDZZa5KKm5aHzdJMuUfpa7JH3dzN6rQivIz4TTj6rQw37NFDUBwILgDDWAevZKFe7KcZ0KFxnepEIv8v0q9C2PSppZU9GrwcxuMbNfM7M1UuHCQhXuEPLdKfZ3xXYkfUXSj5nZG8xsSfjqKOpjXiapNwzTWc2gHWUefVbSH5hZu5ltkPSb5RZ098clfUjSXWbWFX5HTWb2WjN7R7jYw5J+1syaw7ukvHG6AsIWmKckfVTS19z9XDirR9IFM/v98OLJpJldb2Ydc/+4ADA3BGoA9ew2Sf/g7kfc/YnRlwpnhH+h6NZt75B0qeh1r6RzKgTo75lZv6SvSvqCpP89xf6u2I67X1DhQrvXSjqhQvvFX6hwplySfkPSu83sgqR3
qhByF8q7JR2TdFDSf0i6W9LgFMvfocJ39zcqfD8/VOGM8pfD+e+XNCTplKRPaLx9YzqflvST4bskyd1zKvSf3xTWNxq622a4TQCYN1ZoUQMAYGpm9mZJr3X30osiAaCucYYaADApM1tvZs8xs4SZbVXhITZfiLouAIgbLkoEAJTTIOlvVXgK5DkV7rjxoUgrAoAYouUDAAAAqAAtHwAAAEAFCNQAAABABRZdD/WqVas8nU5HXUbdu3jxolpaWqIuAzHF8YFyODZQDscGyony2Ni9e/dT4dNwp7ToAnU6ndYDD5R7CBcWSnd3t7q6uqIuAzHF8YFyODZQDscGyony2DCzwzNZjpYPAAAAoAIEagAAAKACBGoAAACgAgRqAAAAoAIEagAAAKACBGoAAACgAgRqAAAAoAIEagAAAKACBGoAAACgAgRqAAAAoAIEagAAAKACBOoZuDyc0zf3P6kn+i5HXQoAAABihkA9A09eGNRtH+vRPY8+EXUpAAAAiBkC9QxsbF+qdcubtPNgb9SlAAAAIGYI1DNgZspmAu061Ct3j7ocAAAAxAiBeoaymUCnzg/qSO9A1KUAAAAgRgjUM9SZCSSJtg8AAABMQKCeoavXtCpoaVAPgRoAAABFCNQzZGbasaWdQA0AAIAJCNSzkM0EOtI7wP2oAQAAMIZAPQudmZWSpJ5DnKUGAABAAYF6Fq5dv0ytjSn1HDwTdSkAAACICQL1LKSSCd1MHzUAAACKEKhnKZsJtP9Uv85eHIq6FAAAAMQAgXqWsuH9qHfRRw0AAAARqGftWRvb1JBK0PYBAAAASQTqWWtMJbVt0wru9AEAAABJBOo56cwE2nu8T/2DI1GXAgAAgIgRqOegIxMo79Luw2ejLgUAAAARq1qgNrMmM+sxsz1mts/M/nSSZRrN7J/N7ICZ7TSzdLXqmU/bN7crmTDtoo8aAACg7lXzDPWgpJ9w9xsl3STpJWb27JJl3ijprLtfLen9kv6iivXMm5bGlK7f0MaFiQAAAKheoPaC/nB0SfjyksVulfSJcPhuSS80M6tWTfOpMxPo4aPndHk4F3UpAAAAiFCqmhs3s6Sk3ZKulvQ37r6zZJENko5KkruPmFmfpJWSnirZzu2SbpektWvXqru7u5plz8jS/hEN5fL6xJe7tTVIRl3Oguvv74/F3wHxxPGBcjg2UA7HBspZDMdGVQO1u+ck3WRmKyR9wcyud/e9c9jOnZLulKQdO3Z4V1fX/BY6BzcNDOmvH7xHwys2q6vrmqjLWXDd3d2Kw98B8cTxgXI4NlAOxwbKWQzHxoLc5cPdz0m6T9JLSmYdl7RJkswsJalN0pmFqKlSK5ob9Ix1y7STPmoAAIC6Vs27fKwOz0zLzJZKepGk75cs9iVJt4XDr5J0r7uX9lnHVjYTaPfhsxrJ5aMuBQAAABGp5hnq9ZLuM7NHJO2SdI+7f8XM3m1mPx0u8/eSVprZAUm/LekdVaxn3mUzgQaGctp34nzUpQAAACAiVeuhdvdHJG2bZPo7i4YvS/r5atVQbdl0IEnqOdirGzetiLgaAAAARIEnJVZgzfImpVc2q+cQfdQAAAD1ikBdoWwm0K5DvcrnF03rNwAAAOYRgbpC2cxKnRsY1uOn+6dfGAAAADWHQF2hzsxoH/WiuNsfAAAA5hmBukIb25dqfVsT96MGAACoUwTqCpmZOtKFPupFdAttAAAAzBMC9TzIZgKdOj+oI70DUZcCAACABUagngejfdS0fQAAANQfAvU8uHpNq4KWBvUQqAEAAOoOgXoeFPqo2wnUAAAAdYhAPU860oGO9A7oib7LUZcCAACABUSgniedmZWSxGPIAQAA6gyBep5cu36ZWhtTPOAFAACgzhCo50kqmdDNW+ijBgAAqDcE6nmUzQTaf6pfZy8ORV0KAAAAFgiBeh5lw/tR76KPGgAAoG4QqOfRsza2qSGVoO0DAACgjhCo51FjKqltm1Zwpw8AAIA6QqCeZ52Z
QHuP96l/cCTqUgAAALAACNTzrCMTKO/S7sNnoy4FAAAAC4BAPc+2b25XMmHaRR81AABAXSBQz7OWxpSu39DGhYkAAAB1gkBdBZ2ZQA8fPafLw7moSwEAAECVEairIJsONJTLa8/Rc1GXAgAAgCojUFfBjnS7JNH2AQAAUAcI1FWworlBz1i3jPtRAwAA1AECdZVkM4F2Hz6rkVw+6lIAAABQRQTqKslmAg0M5bTvxPmoSwEAAEAVEairJJsOJNFHDQAAUOsI1FWyZnmT0iubtZNADQAAUNMI1FWUzQR64HCv8nmPuhQAAABUCYG6irKZlTo3MKzHT/dHXQoAAACqhEBdRZ2Z0T7qMxFXAgAAgGohUFfRxvalWt/WRB81AABADSNQV5GZqSMdqOdgr9zpowYAAKhFBOoqy2YCnb4wqCO9A1GXAgAAgCogUFfZaB81bR8AAAC1iUBdZVevaVXQ0sADXgAAAGoUgbrKCn3U7QRqAACAGkWgXgAd6UBHegf0RN/lqEsBAADAPCNQL4DOzEpJUs8hzlIDAADUGgL1Arh2/TK1NqZ4wAsAAEANIlAvgFQyoZu30EcNAABQiwjUCySbCbT/VL96Lw5FXQoAAADmEYF6gWTD+1Hvoo8aAACgphCoF8izNrapIZXQLto+AAAAagqBeoE0ppLatmkFd/oAAACoMQTqBdSZCbT3eJ/6B0eiLgUAAADzhEC9gLKZlcq7tPvw2ahLAQAAwDwhUC+gbZtXKJkw+qgBAABqCIF6AbU0pnT9hjbuRw0AAFBDCNQLrDMT6OGj53R5OBd1KQAAAJgHBOoFlk0HGsrltefouahLAQAAwDwgUC+wjnQgM9H2AQAAUCOqFqjNbJOZ3Wdmj5rZPjN76yTLdJlZn5k9HL7eWa164qKteYm2rl3G/agBAABqRKqK2x6R9Dvu/qCZLZO028zucfdHS5a7391fUcU6YiebCXT37mMayeWVSvKPBAAAAItZ1dKcu5909wfD4QuSHpO0oVr7W0yymUADQzntO3E+6lIAAABQoQU5PWpmaUnbJO2cZPaPm9keM/t3M3vmQtQTtWw6kEQfNQAAQC0wd6/uDsxaJX1T0v9y98+XzFsuKe/u/Wb2Mkl/7e7XTLKN2yXdLklr1669+TOf+UxVa14Iv/+tAT2tNaG3bm+KupQ56e/vV2tra9RlIKY4PlAOxwbK4dhAOVEeGy94wQt2u/uO6ZaraqA2syWSviLpa+7+lzNY/pCkHe7+VLllduzY4Q888MD8FRmR37t7j77+6Ck9+McvUiJhUZcza93d3erq6oq6DMQUxwfK4dhAORwbKCfKY8PMZhSoq3mXD5P095IeKxemzWxduJzMLBvWc6ZaNcVJNrNS5waG9fjp/qhLAQAAQAWqeZeP50h6g6TvmdnD4bQ/lLRZktz9I5JeJenNZjYi6ZKk13q1e1BiojMz2kd9RlvXLYu4GgAAAMxV1QK1u39b0pS9DO7+QUkfrFYNcbaxfanWtzVp58FeveHH01GXAwAAgDniJsgRMTNlM4F6DvaqTk7KAwAA1CQCdYQ60oFOXxjUkd6BqEsBAADAHBGoIzTaR72T+1EDAAAsWgTqCF29plVBSwMPeAEAAFjECNQRMjN1pNsJ1AAAAIsYgTpi2cxKHekd0BN9l6MuBQAAAHNAoI5YNh3ej/oQZ6kBAAAWIwJ1xK5dv0ytjSn1HKyLB0QCAADUHAJ1xFLJhG7eQh81AADAYkWgjoFsJtD+U/3qvTgUdSkAAACYJQJ1DIzej3oXfdQAAACLDoE6Bm7Y2KaGVEK7aPsAAABYdAjUMdCYSmrbphXc6QMAAGARIlDHRGcm0N7jfeofHIm6FAAAAMwCgTomspmVyru0+/DZqEsBAADALBCoY2L7lhVKJYz7UQMAACwyBOqYaG5I6Zkb2rTrIGeoAQAAFhMCdYx0ZgI9fPScLg/noi4FAAAAM0SgjpFsOtBQLq89R89FXQoAAABmiEAd
Ix3pQGbiMeQAAACLCIE6Rtqal2jr2mXcjxoAAGARIVDHTDYTaPfhsxrJ5aMuBQAAADNAoI6ZbCbQwFBO+06cj7oUAAAAzACBOmay6UASfdQAAACLBYE6ZtYsb1JmVYt2EqgBAAAWBQJ1DGXTgXYd6lU+71GXAgAAgGkQqGOoIxOo79KwHj/dH3UpAAAAmAaBOoY6M6N91GcirgQAAADTIVDH0Mb2pVrf1kQfNQAAwCJAoI4hM1M2E6jnYK/c6aMGAACIMwJ1THWkA52+MKjDZwaiLgUAAABTIFDH1FgfNY8hBwAAiDUCdUxdvaZVQUsDD3gBAACIOQJ1TJmZOtLtBGoAAICYI1DHWDazUkd6B3Sy71LUpQAAAKAMAnWMZdOj96PmLDUAAEBcEahj7Nr1y9TamNIuLkwEAACILQJ1jKWSCd28hT5qAACAOCNQx1w2E2j/qX71XhyKuhQAAABMgkAdc6P3o6btAwAAIJ4I1DF3w8Y2NaQStH0AAADEFIE65hpTSW3btIIz1AAAADFFoF4EOjOB9h7vU//gSNSlAAAAoASBehHIZlYq79Luw2ejLgUAAAAlCNSLwPYtK5RKmHoOnom6FAAAAJQgUC8CzQ0pPXNDm3Yd5Aw1AABA3BCoF4nOTKCHj57T5eFc1KUAAACgCIF6kcimAw3l8tpz9FzUpQAAAKDIlIHazH6iaDhTMu9nq1UUrtSRDmQm7kcNAAAQM9OdoX5f0fDnSub98TzXgim0NS/R1rXL1MP9qAEAAGJlukBtZYYnG0eVZTOBdh8+q5FcPupSAAAAEJouUHuZ4cnGUWXZTKCBoZz2nTgfdSkAAAAIpaaZ/3Qz+5IKZ6NHhxWOZ8qvhmrIpgNJhT7qGzetiLgaAAAASNMH6luLht9XMq90HFW2ZnmTMqtatPNgr37teU+PuhwAAABomkDt7t8sHjezJZKul3Tc3U9XszBMLpsO9NV9TyifdyUStLEDAABEbbrb5n3EzJ4ZDrdJ2iPpk5IeMrPXTbPuJjO7z8weNbN9ZvbWSZYxM/uAmR0ws0fMbHsFn6UudGQC9V0a1uOn+6MuBQAAAJr+osTnuvu+cPhXJO139xsk3Szp96ZZd0TS77j7dZKeLektZnZdyTIvlXRN+Lpd0odnU3w96syM9lGfibgSAAAASNMH6qGi4RdJ+qIkufsT023Y3U+6+4Ph8AVJj0naULLYrZI+6QXflbTCzNbPtPh6tLF9qda3NWknD3gBAACIhekuSjxnZq+QdFzScyS9UZLMLCVp6Ux3YmZpSdsk7SyZtUHS0aLxY+G0kyXr367CGWytXbtW3d3dM911TdrSPKxv/+AJ3XfffTKLpo+6v7+/7v8OKI/jA+VwbKAcjg2UsxiOjekC9a9L+oCkdZLeVnRm+oWS/nUmOzCzVhWesvg2d5/TDZTd/U5Jd0rSjh07vKuray6bqRnHlx7Wd7+wV5kbskqvaomkhu7ubtX73wHlcXygHI4NlMOxgXIWw7Ex3V0+9kt6ySTTvybpa9NtPLwryOckfcrdPz/JIsclbSoa3xhOwxTG7kd9qDeyQA0AAICCKQO1mX1gqvnufscU65qkv5f0mLv/ZZnFviTpN83sM5I6JfW5+8kyyyJ09ZpWBS0N6jnYq1fv2DT9CgAAAKia6Vo+3iRpr6TPSjqhwhMSZ+o5kt4g6Xtm9nA47Q8lbZYkd/+IpH+T9DJJByQNqHAnEUzDzNSRblcPFyYCAABEbrpAvV7Sz0t6jQq3wftnSXe7+7npNuzu39Y0AdzdXdJbZlYqimUzK/W1fad0su+S1rfN+PpQAAAAzLMpb5vn7mfc/SPu/gIVzh6vkPSomb1hQapDWeP3o+YsNQAAQJSmuw+1JCl8guFbJf2ipH+XtLuaRWF6165frtbGlHYdIlADAABEabqLEt8t6eUqPJTlM5L+wN1HFqIwTC2ZMN28hT5qAACAqE13hvqPVWjzuFHSn0t60MweMbPvmdkjVa8O
U8pmAu0/1a/ei0PTLwwAAICqmO6ixMyCVIE5Ge2j3nWoVy9+5rqIqwEAAKhP012UeHiylwqPC79lYUpEOTdsbFNjKkHbBwAAQISmDNRmttzM/sDMPmhmP2UFvyXpR5JevTAlopzGVFI3bVrBhYkAAAARmq6H+h8lbZX0PUm/Kuk+Sa+S9Ep3v7XKtWEGOjOB9h7vU/8g14oCAABEYboe6qe7+w2SZGYflXRS0mZ3v1z1yjAj2cxK5e89oN2Hz+r5P7Y66nIAAADqznRnqIdHB9w9J+kYYTpetm9ZoVTC1HPwTNSlAAAA1KXpzlDfaGbnw2GTtDQcNxWeHL68qtVhWs0NKV2/oY0LEwEAACIy3V0+ku6+PHwtc/dU0TBhOiaymUB7jvbp8nAu6lIAAADqzowePY54y6YDDeXy2nP0XNSlAAAA1B0CdQ3oSAcyE20fAAAAESBQ14C25iXaunaZergfNQAAwIIjUNeIzkyg3YfPaiSXj7oUAACAukKgrhEdmUADQzntO3F++oUBAAAwbwjUNSKbDiTRRw0AALDQCNQ1Ys3yJmVWtWgngRoAAGBBEahrSDYdaNehXuXzHnUpAAAAdYNAXUOymUB9l4a1//SFqEsBAACoGwTqGpLNFPqod9H2AQAAsGAI1DVkY/tSrW9roo8aAABgARGoa4iZKZsJ1HOwV+70UQMAACwEAnWNyWYCnb4wqMNnBqIuBQAAoC4QqGtMZ4b7UQMAACwkAnWNuWp1q4KWBvUcIlADAAAsBAJ1jTEzdaTbOUMNAACwQAjUNSibWakjvQM62Xcp6lIAAABqHoG6BtFHDQAAsHAI1DXo2vXL1dqYIlADAAAsAAJ1DUomTDdvadcuLkwEAACoOgJ1jcpmAu0/1a/ei0NRlwIAAFDTCNQ1arSPmrPUAAAA1UWgrlE3bGxTYypBHzUAAECVEahrVGMqqW2bV3CGGgAAoMoI1DUsmw6093if+gdHoi4FAACgZhGoa1g2s1J5l3YfPht1KQAAADWLQF3Dtm9ZoVTC1HPwTNSlAAAA1CwCdQ1rbkjp+g1tXJgIAABQRQTqGteZCbTnaJ8uD+eiLgUAAKAmEahrXEc60FAurz1Hz0VdCgAAQE0iUNe4jnQgM9H2AQAAUCUE6hrX1rxEW9cuUw/3owYAAKgKAnUd6MwE2n34rEZy+ahLAQAAqDkE6jrQkQk0MJTTvhPnoy4FAACg5hCo60A2HUiijxoAAKAaCNR1YM3yJmVWtWgngRoAAGDeEajrRDYdaNehXuXzHnUpAAAANYVAXSeymUB9l4a1//SFqEsBAACoKQTqOpHNFPqod9H2AQAAMK8I1HViY/tSrW9roo8aAABgnhGo64SZKZsJ1HOwV+70UQMAAMwXAnUdyWYCnb4wqMNnBqIuBQAAoGZULVCb2cfM7LSZ7S0zv8vM+szs4fD1zmrVgoLODPejBgAAmG/VPEP9cUkvmWaZ+939pvD17irWAklXrW5V0NKgnkMEagAAgPlStUDt7t+SRHKLETNTR7qdM9QAAADzKBXx/n/czPZIOiHpd91932QLmdntkm6XpLVr16q7u3vhKqwxQW5YR3qH9Pmv3qugae6/p/r7+/k7oCyOD5TDsYFyODZQzmI4NqIM1A9K2uLu/Wb2MklflHTNZAu6+52S7pSkHTt2eFdX14IVWWtWHe/TXd//tpLrtqrrpg1z3k53d7f4O6Acjg+Uw7GBcjg2UM5iODYiu8uHu5939/5w+N8kLTGzVVHVUy+uXb9crY0p2j4AAADmSWSB2szWmZmFw9mwljNR1VMvkgnTzVvatYsLEwEAAOZFNW+bd5ek70jaambHzOyNZvYmM3tTuMirJO0Ne6g/IOm1zhNHFkQ2E2j/qX71XhyKuhQAAIBFr2o91O7+umnmf1DSB6u1f5Q3ej/qXYd69eJnrou4GgAAgMWNJyXWoRs2tqkxlaCPGgAAYB4QqOtQYyqpbZtXEKgBAADmAYG6TmXT
gfad6FP/4EjUpQAAACxqBOo6lc2sVN6l3YfPRl0KAADAokagrlPbt6xQKmHqOcidCgEAACpBoK5TzQ0pXb+hjT5qAACAChGo61hnJtCeo326PJyLuhQAAIBFi0BdxzrSgYZyee05ei7qUgAAABYtAnUd60gHMhNtHwAAABUgUNextuYl2rp2mXoOEagBAADmikBd5zozgXYfPqvhXD7qUgAAABYlAnWdy2ZWamAop30nzkddCgAAwKJEoK5zHZl2SdIu+qgBAADmhEBd59Ysa1JmVYt2EqgBAADmhEANZdOBdh3qVT7vUZcCAACw6BCooWwmUN+lYe0/fSHqUgAAABYdAjWUzQSSuB81AADAXBCooY3tS7W+rYlADQAAMAcEasjMlM0E6jnYK3f6qAEAAGaDQA1JhbaP0xcGdfjMQNSlAAAALCoEakgqPDFRoo8aAABgtgjUkCRdtbpVQUsD96MGAACYJQI1JBX6qDvS7dp1iEANAAAwGwRqjMlmVupI74BO9l2KuhQAAIBFg0CNMfRRAwAAzB6BGmOuXb9crY0pAjUAAMAsEKgxJpkw7aCPGgAAYFYI1JigIx1o/6l+9V4ciroUAACARYFAjQlG+6g5Sw0AADAzBGpMcMPGNjWmEvRRAwAAzBCBGhM0ppLatnkFgRoAAGCGCNS4QjazUvtO9Kl/cCTqUgAAAGKPQI0rZNOB8i7tPnw26lIAAABij0CNK2zfskKphKnn4JmoSwEAAIg9AjWu0NyQ0vUb2uijBgAAmAECNSbVmQm052ifLg/noi4FAAAg1gjUmFQ2E2gol9eeo+eiLgUAACDWCNSY1I4tgcxE2wcAAMA0CNSYVFvzEm1du0w9PDERAABgSgRqlNWZCbT78FkN5/JRlwIAABBbBGqUlc2s1MBQTvtOnI+6FAAAgNgiUKOsjky7JGkXfdQAAABlEahR1pplTcqsatFOAjUAAEBZBGpMKZsOtOtQr/J5j7oUAACAWCJQY0rZTKC+S8Paf/pC1KUAAADEEoEaU8pmAkncjxoAAKAcAjWmtLF9qZ7W1kSgBgAAKINAjSmZmToygXoO9sqdPmoAAIBSBGpMK5sJdPrCoA6fGYi6FAAAgNghUGNanfRRAwAAlEWgxrReePS2AAAQ30lEQVSuWt2qoKWB+1EDAABMgkCNaZnZ2P2oAQAAMBGBGjPSkQl0pHdAJ/suRV0KAABArBCoMSP0UQMAAEyuaoHazD5mZqfNbG+Z+WZmHzCzA2b2iJltr1YtqNy165ertTFFoAYAAChRzTPUH5f0kinmv1TSNeHrdkkfrmItqFAyYdqRbidQAwAAlKhaoHb3b0maKn3dKumTXvBdSSvMbH216kHlsplAj5/uV+/FoahLAQAAiI0oe6g3SDpaNH4snIaYyqYLfdTc7QMAAGBcKuoCZsLMblehLURr165Vd3d3tAXVqeG8a0lC+vy39ui/bxrm74Cy+vv7OT4wKY4NlMOxgXIWw7ERZaA+LmlT0fjGcNoV3P1OSXdK0o4dO7yrq6vqxWFyNx/4jk4M5tTaOiL+Diinu7ub4wOT4thAORwbKGcxHBtRtnx8SdIvhXf7eLakPnc/GWE9mIFsZqX2nejTpRGPuhQAAIBYqNoZajO7S1KXpFVmdkzSn0haIknu/hFJ/ybpZZIOSBqQ9CvVqgXzJ5sOlHfpwNlc1KUAAADEQtUCtbu/bpr5Lukt1do/qmP7lhVKJUw/OJuPuhQAAIBY4EmJmJXmhpSu39CmH/RyhhoAAEBaJHf5QLx0Pj3Q337znHa85x5dtbpVV69pHXu/ek2r1rc1ycyiLhMAAGBBEKgxa7/+vKt07olj8mVr9MMnL+orj5xU36XhsfnNDcmioN0yFrQ3By1qSPGPIgAAoLYQqDFrQUuDXppZoq6uGyVJ7q6n+of0wyf7deB04fXDJ/u180dn9IWHxu+EmEqYNq9sHj+bvbpVV4Whe1nTkqg+DgAAQEUI1KiYmWn1
skatXtaoZz995YR5FwdH9KMnL+rAkxcKQfv0RR14sl/3ff+0RvLjt95bt7xJV61p0dUlLSSrlzXSPgIAAGKNQI2qamlM6YaNbbphY9uE6cO5vI70Dkw4o/3D0/363IPH1T84MrbcsqbUhIA9+r6pfalSSdpHAABA9AjUiMSSZEJXrS4E5Bc/c3y6u+vU+cEwaF/QD5+8qAOn+/XN/U/q7t3HxpZrSCaUXtVc0jpSeC1tSEbwiQAAQL0iUCNWzEzr2pq0rq1Jt1yzasK8vkvDY33ao2e0Hz1xXl/d+4SKuke0YcXSsQshi+8+ErQ0LPCnAQAA9YBAjUWjbekSbd/cru2b2ydMHxzJ6dBTE9tHDpzu186DZ3R5ePwBNO3NSyYE7avCs9sbVixVIkGfNgAAmBsCNRa9xlRSW9ct09Z1yyZMz+ddx89d0oHwbPZo0P7q3id0dmD8Nn9NSxJ6+qorz2inVzWrMUX7CAAAmBqBGjUrkTBtCpq1KWjWC7aumTCv9+LQFWe0dx8+qy/tOTG+vkmbg/Hb/F1VFLjblnKbPwAAUECgRl0KWhqUzQTKZoIJ0y8N5Qr92eFZ7QNh2L7/8ac0lBtvH1m9rHH8oTWrW7VqWaMakgk1LkmG74nCeyqhxlRSDamEGlKF8YZUQqmEcTtAAABqBIEaKLK0IanrN7Tp+g0Tb/M3ksvr6NlLE85oHzjdr3956IQuFN3mb6YSpkLInkUIH31vSCbHl512nWTRfsbfG5PjyybpHwcAoCIEamAGUsmEMqtalFnVohdp7dh0d9eTFwZ17tKwBofzGsrlNDiS1+BIXkNF74XhXNFwXkO5vAaHc4X3K9bJaWBoRGcHwvVz+XD74+sM53yKimfx2RI2gxCeLAruiSmDe0MqoR8eG9aZ3ceUSpqSCVMqYUomEkolR4dNqUQhzC9JThxPJSxcbnw8mTQtKRrnIlIAQJwQqIEKmJnWLG/SmuVNC77vfN7HwnhxYB8L8WMhfHz64HQhf3RaLj9hnfOXhseC/pXbzyuXnyTc791Ttc9uVvghkEokxgJ3cVAvDfJLJoyPL1Ma7senJcbmFdZNFC1rSiWLxsNtLynaxoQfEeEPgKSZEgkpaYX5o9OSCVMifE8mNDY8Pq1oeHQbV0zjBwYARIlADSxSiYSpKZFU05Lo70QykstPCOHf/s//Ukf22RrJF8L2SN41kvMJ47m8azg3cbywXH7CeC4cH5uWc+XyeQ2XjI/vY3y8sI+J4yN518DQyPT7naS+SX84xESyJLQnEuOB26wQ1kunJ4rfi+dbSeBPmJI2Mchf+YOgZP4VPwgK2z9yZEiP5B5XwgrHcMKsMGxFw+E1BgkLazKT2fiPDStaPplQuGzRdhLjw2bjnythhWVLtzn6nVnRNkr3k0hcWWPxPq2kVq6RAOoLgRpAxVLJhFLJhJrDZ+esXJrQ5pXN0RZVBe5eEvYLPxJGSsbHg3xhPO+uXF7K5T0cduXclQ+3NTa/aNrY8IRpGpuWy5fMD7cxtv2ifY2/l6w/tqyuWHa09sGRov1OVr+78uFnm1hLyXwv+kFyYH+0f8gFUBqwrwzpo0F8Yiif8l2TTx/9oXDFuDThx4BUum9dsf/RHyeFdcdrLGx78vErtp0orD/5tos+S2LiZ/vRwWEdSP5orHYrWXfCtHB/Kvzviu9FpdPCdYu/l9F5k25DkkprLbMN02iNE6clTCW1h9OKli+eljCFtUz8G47tI6Gx7UyYXvJ5r1iGH3cLgkANADNkVmgBicE/Cixa9913n573/K6xcO4ehnl3eRjsC6+S4fzE6V70A2LCdsIfDh4um8v72PDYfsKQf+W+FG538v0Ub7N4n2Pr50vrnnxeca0uTajHNXGZcu+jtY4uPzZe8r2N/rgbW2Z0f+E+J4xfsQ3JFS435b4K7yoZH9/2eO3T+sFj1Tv46tikQTv8JTEh5BctM/aDYpJ1VfyDoWRdSUokptmmrvxBMPbjomSbMtOF
85f0zJsHtXpZ4wJ+a7NDoAYALBiz8VYQ1J/SgF38/q3779dzbrlFXhTYi39g+FhIL/rxEf6rx4RpY+F9NPwXbUOj25r78l40b+KPkonbcE388TG6rbLTwuVV8oNp9HsoXWd8XtH3M9l0FW1zknVHfwhNtq40yXcRLqOxzz/FNkvrL6w29TYn+ey5gfAMfowRqAEAwIIo9LBLSV2ZjpamTMubeGgWrtTd3a2VrfE9Oy1JiagLAAAAABYzAjUAAABQAQI1AAAAUAECNQAAAFABAjUAAABQAQI1AAAAUAECNQAAAFABAjUAAABQAQI1AAAAUAECNQAAAFABAjUAAABQAQI1AAAAUAECNQAAAFABc/eoa5gVM3tS0uGo64BWSXoq6iIQWxwfKIdjA+VwbKCcKI+NLe6+erqFFl2gRjyY2QPuviPqOhBPHB8oh2MD5XBsoJzFcGzQ8gEAAABUgEANAAAAVIBAjbm6M+oCEGscHyiHYwPlcGygnNgfG/RQAwAAABXgDDUAAABQAQI1ZsXMNpnZfWb2qJntM7O3Rl0T4sXMkmb2kJl9JepaEB9mtsLM7jaz75vZY2b241HXhPgws7eH/03Za2Z3mVlT1DUhGmb2MTM7bWZ7i6YFZnaPmT0evrdHWeNkCNSYrRFJv+Pu10l6tqS3mNl1EdeEeHmrpMeiLgKx89eSvuruz5B0ozhGEDKzDZLukLTD3a+XlJT02mirQoQ+LuklJdPeIekb7n6NpG+E47FCoMasuPtJd38wHL6gwn8UN0RbFeLCzDZKermkj0ZdC+LDzNokPU/S30uSuw+5+7loq0LMpCQtNbOUpGZJJyKuBxFx929J6i2ZfKukT4TDn5D0ygUtagYI1JgzM0tL2iZpZ7SVIEb+StLvScpHXQhiJSPpSUn/ELYDfdTMWqIuCvHg7sclvU/SEUknJfW5+9ejrQoxs9bdT4bDT0haG2UxkyFQY07MrFXS5yS9zd3PR10Pomdmr5B02t13R10LYiclabukD7v7NkkXFcN/skU0wn7YW1X44fU0SS1m9ovRVoW48sLt6WJ3izoCNWbNzJaoEKY/5e6fj7oexMZzJP20mR2S9BlJP2Fm/xRtSYiJY5KOufvov2bdrULABiTpJyUddPcn3X1Y0ucl/beIa0K8nDKz9ZIUvp+OuJ4rEKgxK2ZmKvRBPubufxl1PYgPd/8Dd9/o7mkVLii61905ywS5+xOSjprZ1nDSCyU9GmFJiJcjkp5tZs3hf2NeKC5axURfknRbOHybpH+JsJZJEagxW8+R9AYVzj4+HL5eFnVRAGLvtyR9yswekXSTpD+LuB7ERPgvF3dLelDS91TIJrF/Mh6qw8zukvQdSVvN7JiZvVHSeyW9yMweV+FfNN4bZY2T4UmJAAAAQAU4Qw0AAABUgEANAAAAVIBADQAAAFSAQA0AAABUgEANAAAAVIBADQARMLNuM9uxAPu5w8weM7NPlUzfYWYfCIe7zGzeHqRhZmkze/1k+wKAWpSKugAAwOyYWcrdR2a4+G9I+kl3P1Y80d0fkPRAONolqV/Sf81TDWlJr5f06Un2BQA1hzPUAFBGeKb1MTP7OzPbZ2ZfN7Ol4byxM8xmtip85LrM7JfN7Itmdo+ZHTKz3zSz3zazh8zsu2YWFO3iDeHDkfaaWTZcv8XMPmZmPeE6txZt90tmdq+kb0xS62+H29lrZm8Lp31E0tMl/buZvb1k+S4z+4qZpSW9SdLbw1qea2arzexzZrYrfD0nXOddZvaPZvafkv4x/H7uN7MHw9foWe73SnpuuL23j+4r3EYQfj+PhN/Hs4q2/bHwe/2Rmd1R9H38q5ntCT/bayr7qwLA/OMMNQBM7RpJr3P3XzOzz0r6OUn/NM0610vaJqlJ0gFJv+/u28zs/ZJ+SdJfhcs1u/tNZvY8SR8L1/sjFR7b/j/MbIWkHjP7j3D57ZKe5e69xTszs5sl/YqkTkkmaaeZfdPd32RmL5H0And/
arJC3f1QGLz73f194fY+Len97v5tM9ss6WuSrg1XuU7SLe5+ycyaJb3I3S+b2TWS7pK0Q9I7JP2uu78i3F5X0S7/VNJD7v5KM/sJSZ9U4cmJkvQMSS+QtEzSD8zsw5JeIumEu7883FbbNN89ACw4AjUATO2guz8cDu9WoZ1hOve5+wVJF8ysT9KXw+nfk/SsouXukiR3/5aZLQ8D9E9J+mkz+91wmSZJm8Phe0rDdOgWSV9w94uSZGafl/RcSQ/N5ANO4iclXWdmo+PLzaw1HP6Su18Kh5dI+qCZ3SQpJ+nHZrDtW1T4USJ3v9fMVprZ8nDev7r7oKRBMzstaa0K39n/NbO/kPQVd79/jp8JAKqGQA0AUxssGs5JWhoOj2i8ba5pinXyReN5Tfz/XS9Zz1U4w/xz7v6D4hlm1inp4qwqn7uEpGe7++WSGlRSw9slnZJ0Y7jOhOXnoPS7Trn7fjPbLullkt5jZt9w93dXuB8AmFf0UAPA3BySdHM4/Ko5buM1kmRmt0jqc/c+FdorfsvC9Gpm22awnfslvdLMms2sRdLPhNNm6oIKbRajvi7pt0ZHwjPQk2mTdNLd85LeIClZZnultf5CuN0uSU+5+/lyhZnZ0yQNuPs/Sfo/KrS9AECsEKgBYG7eJ+nNZvaQpFVz3MblcP2PSHpjOO1/qtBK8YiZ7QvHp+TuD0r6uKQeSTslfdTdZ9Pu8WVJPzN6UaKkOyTtCC8cfFSFixYn8yFJt5nZHhX6n0fPXj8iKRdeSPj2knXeJelmM3tEhYsXb5umthtU6CN/WNKfSHrPLD4XACwIcy/9F0cAAAAAM8UZagAAAKACBGoAAACgAgRqAAAAoAIEagAAAKACBGoAAACgAgRqAAAAoAIEagAAAKACBGoAAACgAv8fyj6dja3vwO4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 864x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# create an array of num_iters\n",
    "iter_array = list(range(1, 11))\n",
    "# create learning curve plot\n",
    "plot_learning_curve(iter_array, train, validation, 0.05, 20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After 3 iterations, alternating least squares starts to converge to an error of around 0.8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model testing\n",
    "Finally, predict ratings for the held-out test set and check the error on this out-of-sample data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The out-of-sample RMSE of rating predictions is 0.8156\n"
     ]
    }
   ],
   "source": [
    "# make prediction using test data\n",
    "test_data = test.map(lambda p: (p[0], p[1]))\n",
    "predictions = final_model.predictAll(test_data).map(lambda r: ((r[0], r[1]), r[2]))\n",
    "# get the rating result\n",
    "ratesAndPreds = test.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)\n",
    "# get the RMSE\n",
    "MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean()\n",
    "error = math.sqrt(MSE)\n",
    "print('The out-of-sample RMSE of rating predictions is', round(error, 4))"
   ]
  },
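  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the metric itself, the RMSE computation above can be sketched in plain Python without Spark. This is a minimal illustration and not part of the original pipeline: the `rmse` helper and the `toy_pairs` values are hypothetical and only stand in for the joined `(actual, predicted)` ratings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def rmse(pairs):\n",
    "    \"\"\"Root mean squared error over (actual, predicted) pairs.\"\"\"\n",
    "    squared_errors = [(actual - predicted) ** 2 for actual, predicted in pairs]\n",
    "    return math.sqrt(sum(squared_errors) / len(squared_errors))\n",
    "\n",
    "\n",
    "# hypothetical (actual, predicted) ratings, made up for illustration only\n",
    "toy_pairs = [(4.0, 3.5), (3.0, 3.5), (5.0, 4.0)]\n",
    "print('Toy RMSE:', round(rmse(toy_pairs), 4))"
   ]
  },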
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make movie recommendations for myself\n",
    "We need to define a function that takes a new user's favorite movies, treats them as top-rated, and outputs the top N recommendations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_movieId(df_movies, fav_movie_list):\n",
    "    \"\"\"\n",
    "    return all movieId(s) of user's favorite movies\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    df_movies: spark Dataframe, movies data\n",
    "    \n",
    "    fav_movie_list: list, user's list of favorite movies\n",
    "    \n",
    "    Return\n",
    "    ------\n",
    "    movieId_list: list of movieId(s)\n",
    "    \"\"\"\n",
    "    movieId_list = []\n",
    "    for movie in fav_movie_list:\n",
    "        # match on the DataFrame passed in, not on a global variable\n",
    "        movieIds = df_movies \\\n",
    "            .filter(df_movies.title.like('%{}%'.format(movie))) \\\n",
    "            .select('movieId') \\\n",
    "            .rdd \\\n",
    "            .map(lambda r: r[0]) \\\n",
    "            .collect()\n",
    "        movieId_list.extend(movieIds)\n",
    "    return list(set(movieId_list))\n",
    "\n",
    "\n",
    "def add_new_user_to_data(train_data, movieId_list, spark_context):\n",
    "    \"\"\"\n",
    "    add new rows with new user, user's movie and ratings to\n",
    "    existing train data\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_data: spark RDD, ratings data\n",
    "    \n",
    "    movieId_list: list, list of movieId(s)\n",
    "\n",
    "    spark_context: Spark Context object\n",
    "    \n",
    "    Return\n",
    "    ------\n",
    "    new train data with the new user's rows\n",
    "    \"\"\"\n",
    "    # get new user id\n",
    "    new_id = train_data.map(lambda r: r[0]).max() + 1\n",
    "    # get max rating\n",
    "    max_rating = train_data.map(lambda r: r[2]).max()\n",
    "    # create new user rdd: every favorite movie gets the max rating\n",
    "    user_rows = [(new_id, movieId, max_rating) for movieId in movieId_list]\n",
    "    new_rdd = spark_context.parallelize(user_rows)\n",
    "    # return new train data\n",
    "    return train_data.union(new_rdd)\n",
    "\n",
    "\n",
    "def get_inference_data(train_data, df_movies, movieId_list):\n",
    "    \"\"\"\n",
    "    return a rdd with the userid and all movies (except ones in movieId_list)\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_data: spark RDD, ratings data\n",
    "\n",
    "    df_movies: spark Dataframe, movies data\n",
    "    \n",
    "    movieId_list: list, list of movieId(s)\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    inference data: Spark RDD\n",
    "    \"\"\"\n",
    "    # get new user id\n",
    "    new_id = train_data.map(lambda r: r[0]).max() + 1\n",
    "    # return inference rdd\n",
    "    return df_movies.rdd \\\n",
    "        .map(lambda r: r[0]) \\\n",
    "        .distinct() \\\n",
    "        .filter(lambda x: x not in movieId_list) \\\n",
    "        .map(lambda x: (new_id, x))\n",
    "\n",
    "\n",
    "def make_recommendation(best_model_params, ratings_data, df_movies, \n",
    "                        fav_movie_list, n_recommendations, spark_context):\n",
    "    \"\"\"\n",
    "    return top n movie recommendation based on user's input list of favorite movies\n",
    "\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    best_model_params: dict, {'iterations': iter, 'rank': rank, 'lambda_': reg}\n",
    "\n",
    "    ratings_data: spark RDD, ratings data\n",
    "\n",
    "    df_movies: spark Dataframe, movies data\n",
    "\n",
    "    fav_movie_list: list, user's list of favorite movies\n",
    "\n",
    "    n_recommendations: int, top n recommendations\n",
    "\n",
    "    spark_context: Spark Context object\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    list of top n movie recommendations\n",
    "    \"\"\"\n",
    "    # modify train data by adding new user's rows\n",
    "    movieId_list = get_movieId(df_movies, fav_movie_list)\n",
    "    train_data = add_new_user_to_data(ratings_data, movieId_list, spark_context)\n",
    "    \n",
    "    # train best ALS\n",
    "    model = ALS.train(\n",
    "        ratings=train_data,\n",
    "        iterations=best_model_params.get('iterations', None),\n",
    "        rank=best_model_params.get('rank', None),\n",
    "        lambda_=best_model_params.get('lambda_', None),\n",
    "        seed=99)\n",
    "    \n",
    "    # get inference rdd\n",
    "    inference_rdd = get_inference_data(ratings_data, df_movies, movieId_list)\n",
    "    \n",
    "    # inference\n",
    "    predictions = model.predictAll(inference_rdd).map(lambda r: (r[1], r[2]))\n",
    "    \n",
    "    # get top n movieId\n",
    "    topn_rows = predictions.sortBy(lambda r: r[1], ascending=False).take(n_recommendations)\n",
    "    topn_ids = [r[0] for r in topn_rows]\n",
    "    \n",
    "    # return movie titles (the isin filter does not keep the predicted-rating order)\n",
    "    return df_movies.filter(df_movies.movieId.isin(topn_ids)) \\\n",
    "                    .select('title') \\\n",
    "                    .rdd \\\n",
    "                    .map(lambda r: r[0]) \\\n",
    "                    .collect()"
   ]
  },
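  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One caveat about the last step of `make_recommendation`: filtering with `isin` returns titles in whatever order the rows come back, not in predicted-rating order. A minimal sketch of an order-preserving lookup follows. It uses a plain Python dict in place of the Spark DataFrame, and the ids, ratings and titles are hypothetical placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical (movieId, predicted_rating) rows, already sorted by rating\n",
    "topn_rows = [(30, 4.9), (10, 4.7), (20, 4.5)]\n",
    "# hypothetical movieId -> title lookup, standing in for df_movies\n",
    "id_to_title = {10: 'Movie A', 20: 'Movie B', 30: 'Movie C'}\n",
    "\n",
    "# walk the ranked ids in order so the output keeps the ranking\n",
    "ranked_titles = [id_to_title[movie_id]\n",
    "                 for movie_id, _ in topn_rows\n",
    "                 if movie_id in id_to_title]\n",
    "print(ranked_titles)"
   ]
  },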
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's pretend I am a new user of this recommender system. I will input a handful of my all-time favorite movies, and the system should output the top N movie recommendations for me to watch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recommendations for Iron Man:\n",
      "1: Pearl Jam: Immagine in Cornice - Live in Italy 2006 (2007)\n",
      "2: \"Diebuster \"\"Top wo Narae 2\"\" (2004)\"\n",
      "3: Dinosaur Island (1994)\n",
      "4: Melhores do Mundo - Hermanoteu na Terra de Godah (2009)\n",
      "5: Into Pitch Black (2000)\n",
      "6: Kizumonogatari II: Passionate Blood (2016)\n",
      "7: Countdown (2004)\n",
      "8: Heroes Above All (2017)\n",
      "9: Stone Cold Steve Austin: The Bottom Line on the Most Popular Superstar of All Time (2011)\n",
      "10: WWE: Ladies and Gentlemen, My Name Is Paul Heyman (2014)\n"
     ]
    }
   ],
   "source": [
    "# my favorite movies\n",
    "my_favorite_movies = ['Iron Man']\n",
    "\n",
    "# get recommends\n",
    "recommends = make_recommendation(\n",
    "    best_model_params={'iterations': 10, 'rank': 20, 'lambda_': 0.05}, \n",
    "    ratings_data=rating_data, \n",
    "    df_movies=movies, \n",
    "    fav_movie_list=my_favorite_movies, \n",
    "    n_recommendations=10, \n",
    "    spark_context=sc)\n",
    "\n",
    "print('Recommendations for {}:'.format(my_favorite_movies[0]))\n",
    "for i, title in enumerate(recommends):\n",
    "    print('{0}: {1}'.format(i+1, title))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This list of movie recommendations looks completely different from the list produced by my previous **KNN** model recommender. Not only does it recommend movies released outside the 2007 to 2009 period, it also recommends movies that are less well known. This can offer users an element of surprise, so that they won't get bored by seeing the same popular movies all the time.\n",
    "\n",
    "This list of recommendations can therefore be blended into the list from the **KNN** model recommender"
   ]
  }
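  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One simple way to blend two ranked lists is to interleave them and drop duplicates. The sketch below is only an illustration under stated assumptions: `blend_recommendations` is a hypothetical helper, and the titles in `als_recs` and `knn_recs` are placeholders rather than real model output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import zip_longest\n",
    "\n",
    "\n",
    "def blend_recommendations(list_a, list_b, n):\n",
    "    \"\"\"Interleave two ranked lists, drop duplicates, keep the first n titles.\"\"\"\n",
    "    blended = []\n",
    "    for pair in zip_longest(list_a, list_b):\n",
    "        for title in pair:\n",
    "            # zip_longest pads the shorter list with None\n",
    "            if title is not None and title not in blended:\n",
    "                blended.append(title)\n",
    "    return blended[:n]\n",
    "\n",
    "\n",
    "# hypothetical placeholder titles, not real model output\n",
    "als_recs = ['Movie A', 'Movie B', 'Movie C']\n",
    "knn_recs = ['Movie B', 'Movie D', 'Movie E']\n",
    "print(blend_recommendations(als_recs, knn_recs, 4))"
   ]
  }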
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
